

#include "main.h"

#include <pcap.h>

/******************************************************************************/


/*
 * Fixed part of an IPv4 header as it appears on the wire.
 * npcap_handler() overlays this structure on captured frames; it reads
 * total_len through ntohs() and uses src_ip as-is (network byte order).
 *
 * NOTE(review): the order of the two 4-bit fields relies on the compiler
 * allocating bit-fields from the least significant bit (implementation-defined
 * in C) -- confirm for every target this is built for.
 */
typedef struct ip_hdr_t
{
	u8  h_len:4;         // header length field
	u8  version:4;       // IP version
	u8  tos;             // type of service
	u16 total_len;       // total length of the packet (network byte order)

	u16 ident;           // datagram identifier
	u16 frag_and_flags;  // flags and fragment offset

	u8  ttl;             // time to live
	u8  proto;           // payload protocol (TCP, UDP, ...)
	u16 checksum;        // IP header checksum

	u32 src_ip;          // source address (network byte order)
	u32 dst_ip;          // destination address
}ip_hdr;


/*
 * 8-byte ICMP header placed in front of every tunnel payload.
 * rawio_send_icmp() fills it in; npcap_handler() reads `id` from it.
 */
typedef struct icmp_hdr_t
{
	u8  type;   // CLIENT_PROTO or SERVER_PROTO, depending on the sender
	u8  code;   // always 0x70 for tunnel packets
	u16 cksum;  // checksum computed by ip_checksum() over header + payload
	u16 id;     // tunnel id (copied from rawio_tunnel.id without byte swapping)
	u16 seq;    // sequence number (copied from rawio_tunnel.seq)
}icmp_hdr;


/*
 * UDP header layout.
 * Not referenced by the code visible in this file.
 */
typedef struct udp_hdr_t
{
	u16 src_port;   // source port
	u16 dst_port;   // destination port
	u16 total_len;  // length field
	u16 cksum;      // checksum
}udp_hdr;


/*
 * State of one raw tunnel (one peer).
 */
typedef struct rawio_tunnel_t {
	// Remote peer IP and port (port is 0 in ICMP mode)
	struct sockaddr_in baddr;

	// Incremented by check_tunnel() and reset to 0 whenever a packet
	// arrives; the tunnel is considered dead at TUNNEL_TIMEOUT.
	int timeout;

	// for icmp
	// id: tunnel identifier (also used as the peer port in non-ICMP modes)
	int id;
	// seq: incremented for every ICMP packet sent by the client
	int seq;

	// for kcp
	// Opaque handle returned by new_kcp_tunnel()
	void *ktun;
}rawio_tunnel;

// Number of check_tunnel() passes without traffic after which a tunnel
// is treated as closed.
#define TUNNEL_TIMEOUT 10

// pcap handle used to receive ICMP packets
static pcap_t *pcap_fp;
// Socket used for rawio traffic (raw ICMP socket or UDP socket)
static SOCKET rawio_socket;


// Used by the client: its single tunnel, and whether the server has
// answered (1) or not / timed out (0).
static rawio_tunnel *client_tun;
static int client_state;

// Used by the server: table of active tunnels (NULL = free slot)
#define MAX_TUNNEL 64
static rawio_tunnel *server_tun[MAX_TUNNEL];
// Used by the server when acting as a reverse proxy
// (the client also points it at its own tunnel in rawio_init)
static rawio_tunnel *listen_tun;


// Receive thread (ICMP mode) and the stream callback/argument handed to
// every new KCP tunnel.
static pthread_t recv_tid;
static void (*stream_recv_cb)(void *tun, const void *buf, int len);
static void *stream_recv_arg;


// Payload that identifies a connect/heartbeat packet
// (sent by rawio_send() when buf is NULL, recognised by rawio_input()).
static u32 magic_pkt[2] = {0x76057810, 0x37219527};

/******************************************************************************/


static int npcap_init(char *devname)
{
	pcap_if_t *alldevs;
	pcap_if_t *d = NULL;
	char errbuf[PCAP_ERRBUF_SIZE];
	int found = 0;
	struct in_addr iaddr, daddr;

	printf("npcap_init ...\n");

	if(devname==NULL){
		printf("No interface ip!\n");
		return -1;
	}
	iaddr.s_addr = inet_addr(devname);


	/* Retrieve the device list */
	if(pcap_findalldevs(&alldevs, errbuf) == -1){
		fprintf(stderr,"Error in pcap_findalldevs: %s\n", errbuf);
		exit(1);
	}

	for(d=alldevs; d; d=d->next){
		pcap_addr_t *addr =d->addresses;
		while(addr){
			daddr = ((struct sockaddr_in*)addr->addr)->sin_addr;
			if(iaddr.s_addr == daddr.s_addr){
				printf("Use %s (%s)\n", d->name, d->description);
				found = 1;
				break;
			}
			addr = addr->next;
		}
		if(found)
			break;
	}
	if(d==NULL){
		printf("No interface found(%s)!\n", devname);
		return -1;
	}

	pcap_fp = pcap_open_live(d->name, 65536, 0, 10, errbuf);
	if(pcap_fp==NULL){
		printf("\nUnable to open the adapter!\n");
		return -3;
	}

	struct bpf_program fcode;
	char filter[64];

	if(gcfg.run_mode==CLIENT_MODE)
		sprintf(filter, "icmp[0] = 0x%02x and icmp[1] = 0x70", SERVER_PROTO);
	else
		sprintf(filter, "icmp[0] = 0x%02x and icmp[1] = 0x70", CLIENT_PROTO);

	if(pcap_compile(pcap_fp, &fcode, filter, 1, PCAP_NETMASK_UNKNOWN)<0){
		printf("\nError compiling filter: wrong syntax.\n");
		pcap_close(pcap_fp);
		return -4;
	}
	pcap_setfilter(pcap_fp, &fcode); 


	if(gcfg.run_mode==SERVER_MODE){
		int retv;
		printf("Check firewall rules ...\n");
#if defined(__linux__)
		retv = system("iptables -C INPUT -p icmp --icmp-type 0x08/0x70 -j DROP");
		if(retv){
			retv = system("iptables -A INPUT -p icmp --icmp-type 0x08/0x70 -j DROP");
			if(WEXITSTATUS(retv)){
				printf("\n  Add iptable rules failed!\n");
				exit(-1);
			}else{
				printf("Add firewall rule done.\n");
			}
		}
#elif defined(__WIN32__)
		retv = system("netsh advfirewall firewall show rule name=kkicmp");
		if(retv){
			retv = system("netsh advfirewall firewall add rule name=kkicmp protocol=icmpv4:8,112 dir=in action=block");
			if(retv){
				printf("Please add firewall rule with Administrator:\n");
				printf("  netsh advfirewall firewall add rule name=kkicmp protocol=icmpv4:8,112 dir=in action=block\n");
				exit(-1);
			}else{
				printf("Add firewall rule done.\n");
			}
		}
#endif
	}

	return 0;
}


/******************************************************************************/


/*
 * Return a printable name for the configured raw transport
 * (gcfg.raw_mode). Unrecognised modes yield "Unknow".
 * The returned string is a literal and must not be modified or freed.
 */
char *rawio_tunnel_type(void)
{
	char *name = "Unknow";

	if(gcfg.raw_mode==RAWMODE_ICMP){
		name = "ICMP";
	}else if(gcfg.raw_mode==RAWMODE_UDP){
		name = "UDP";
	}else if(gcfg.raw_mode==RAWMODE_FAKETCP){
		name = "FAKETCP";
	}

	return name;
}


/*
 * Allocate and initialise a tunnel towards `peer_addr`.
 *
 * peer_addr: peer IPv4 address, stored unchanged into sin_addr.s_addr
 *            (i.e. expected in network byte order).
 * id:        tunnel identifier; in non-ICMP modes it is also used as the
 *            peer port (converted with htons).
 *
 * Returns the new tunnel, or NULL if memory could not be allocated.
 * The tunnel is released with close_tunnel().
 */
static rawio_tunnel *new_rawio_tunnel(u32 peer_addr, int id)
{
	rawio_tunnel *tun;

	/* calloc so that timeout, seq and the sockaddr padding start at zero */
	tun = calloc(1, sizeof(*tun));
	if(tun==NULL){
		printf("new_rawio_tunnel: out of memory!\n");
		return NULL;
	}

	tun->baddr.sin_family = AF_INET;
	if(gcfg.raw_mode!=RAWMODE_ICMP){
		tun->baddr.sin_port = htons(id);
	}else{
		/* ICMP has no ports */
		tun->baddr.sin_port = 0;
	}
	tun->baddr.sin_addr.s_addr = peer_addr;
	tun->id = id;
	tun->seq = 0;
	tun->timeout = 0;

	/* NOTE(review): the failure behaviour of new_kcp_tunnel() is not
	 * visible here, so its result is stored unchecked as before. */
	tun->ktun = new_kcp_tunnel(tun, stream_recv_cb, stream_recv_arg);

	return tun;
}


/*
 * Store `tun` in the first free slot of the server tunnel table.
 * If every slot is occupied a message is printed and the tunnel is
 * not stored.
 */
static void add_tunnel(rawio_tunnel *tun)
{
	rawio_tunnel **slot = server_tun;
	rawio_tunnel **end = server_tun + MAX_TUNNEL;

	while(slot<end){
		if(*slot==NULL){
			*slot = tun;
			return;
		}
		slot++;
	}

	printf("Too many tunnel!\n");
}


/*
 * Look up the server tunnel whose peer address equals `paddr` and whose
 * id equals `id`.
 * Returns the matching tunnel, or NULL if there is none.
 */
static rawio_tunnel *find_tunnel(u32 paddr, int id)
{
	int slot = 0;

	while(slot<MAX_TUNNEL){
		rawio_tunnel *cand = server_tun[slot++];

		if(cand!=NULL && cand->baddr.sin_addr.s_addr==paddr && cand->id==id)
			return cand;
	}

	return NULL;
}


/*
 * Log the peer, shut down the tunnel's stream and KCP state, then free
 * the tunnel itself. `tun` must not be used afterwards.
 */
static void close_tunnel(rawio_tunnel *tun)
{
	const char *peer_ip = inet_ntoa(tun->baddr.sin_addr);
	int peer_port = ntohs(tun->baddr.sin_port);

	printf("Close %s tunnel from %s:%d\n", rawio_tunnel_type(), peer_ip, peer_port);

	/* Stream first, then the KCP tunnel that carries it */
	close_tunnel_stream(tun->ktun);
	close_kcp_tunnel(tun->ktun);

	free(tun);
}


/*
 * Return the KCP handle of the current listen tunnel,
 * or NULL when no listen tunnel is set.
 */
void *rawio_listen_tun(void)
{
	return listen_tun ? listen_tun->ktun : NULL;
}


/******************************************************************************/


/*
 * Dispatch one received tunnel payload.
 *
 * src_addr: sender IPv4 address (as stored in sin_addr.s_addr)
 * src_port: sender UDP port, or the ICMP id in ICMP mode
 * buf/len:  payload without transport headers
 *
 * A payload equal to magic_pkt is a connect/heartbeat packet:
 *   - server: create the tunnel if it is new, answer with a magic packet
 *     and refresh the tunnel's timeout;
 *   - client: mark the tunnel as connected and refresh its timeout.
 * Any other payload is handed to kcp_input() of the matching tunnel.
 * Packets from unknown peers are reported and dropped.
 */
static void rawio_input(u32 src_addr, u32 src_port, u8 *buf, int len)
{
	rawio_tunnel *tun;
	struct in_addr saddr;

	saddr.s_addr = src_addr;

	if(buf==NULL || len<0)
		return;

	//hex_dump("rawio_input", buf, len);

	/* memcmp avoids reading the payload through a misaligned u32 pointer */
	if(len==(int)sizeof(magic_pkt) && memcmp(buf, magic_pkt, sizeof(magic_pkt))==0){
		// A packet carrying only the magic words is a connect/heartbeat packet.
		if(gcfg.run_mode==SERVER_MODE){
			tun = find_tunnel(src_addr, src_port);
			if(tun==NULL){
				tun = new_rawio_tunnel(src_addr, src_port);
				if(tun==NULL)
					return;
				add_tunnel(tun);
				/* add_tunnel() cannot report a full table; verify that the
				 * tunnel was really stored, otherwise it would be leaked and
				 * never found again. */
				if(find_tunnel(src_addr, src_port)!=tun){
					close_tunnel(tun);
					return;
				}
				printf("New %s tunnel from %s:%d\n", rawio_tunnel_type(), inet_ntoa(saddr), (int)src_port);
			}
			// The server answers the client
			rawio_send(tun, NULL, 0);
			tun->timeout = 0;

			if(gcfg.listen_port){
				// Server-side reverse proxy
				listen_tun = tun;
			}
		}else{
			// The client received the server's answer
			if(client_tun==NULL){
				/* Client tunnel not set (e.g. a packet arrived before
				 * initialisation finished) */
				return;
			}
			if(client_state==0){
				printf("New %s tunnel to %s success.\n", rawio_tunnel_type(), gcfg.target_ip);
				client_state = 1;
			}
			client_tun->timeout = 0;
		}
		return;
	}

	//printf("rawio_input: len=%d\n", len);

	if(gcfg.run_mode==SERVER_MODE){
		tun = find_tunnel(src_addr, src_port);
	}else{
		tun = client_tun;
	}

	/* Check before touching the tunnel: data may come from unknown peers */
	if(tun==NULL){
		printf("Unknow %s packet: %s:%d\n", rawio_tunnel_type(), inet_ntoa(saddr), (int)src_port);
		return;
	}
	tun->timeout = 0;

	kcp_input(tun->ktun, buf, len);
}


void handle_udp_recv(void *tunnel)
{
	struct sockaddr_in saddr;
	socklen_t alen = sizeof(saddr);
	u8 pbuf[1536];
	int retv;

	retv = recvfrom(rawio_socket, (char*)pbuf, sizeof(pbuf), 0, (struct sockaddr *)&saddr, &alen);
	if(retv<0){
		printf("udp_tunnel recv failed! %d\n", WSAGetLastError());
		return;
	}

	rawio_input(saddr.sin_addr.s_addr, ntohs(saddr.sin_port), pbuf, retv);
}

/******************************************************************************/


/*
 * pcap callback: extract the tunnel payload from a captured
 * Ethernet + IPv4 + ICMP frame and feed it to rawio_input().
 *
 * user:  unused
 * h:     capture header; caplen bounds every access into `bytes`
 * bytes: captured frame, assumed to start with a 14-byte Ethernet header
 *        NOTE(review): other link types are not handled -- confirm the
 *        capture device is always Ethernet.
 *
 * Frames that are truncated, carry a malformed IP header, or whose IP
 * total length does not fit the captured data are dropped.
 */
static void npcap_handler(u_char *user, const struct pcap_pkthdr *h, const u_char *bytes)
{
	const unsigned int eth_len = 14;   /* Ethernet header */
	const unsigned int icmp_len = 8;   /* ICMP header */
	unsigned int ihl, total, payload_off;
	ip_hdr *ip;
	icmp_hdr *icmp;
	int len;

	(void)user;

	if(h==NULL || bytes==NULL)
		return;
	if(h->caplen < eth_len+20+icmp_len)
		return;

	/* Real IP header length (low nibble of the first byte, in 32-bit
	 * words), so packets with IP options are parsed correctly. */
	ihl = (bytes[eth_len] & 0x0f) * 4;
	if(ihl<20)
		return;

	payload_off = eth_len + ihl + icmp_len;
	if(h->caplen < payload_off)
		return;

	ip = (ip_hdr*)(bytes+eth_len);
	icmp = (icmp_hdr*)(bytes+eth_len+ihl);

	total = ntohs(ip->total_len);
	if(total < ihl+icmp_len)
		return;
	len = (int)(total - ihl - icmp_len);

	/* Never read past what was actually captured */
	if(payload_off + (unsigned int)len > h->caplen)
		return;

	rawio_input(ip->src_ip, icmp->id, (u8*)bytes+payload_off, len);
}


/*
 * Thread entry: run the pcap capture loop, delivering every captured
 * packet to npcap_handler(). The loop is restarted whenever pcap_loop()
 * returns without an error; any negative result ends the thread (not only
 * PCAP_ERROR, so that other error codes cannot cause a busy loop).
 *
 * arg: unused.
 * Returns NULL.
 */
static void *npcap_recv_thread(void *arg)
{
	int retv;

	(void)arg;

	for(;;){
		retv = pcap_loop(pcap_fp, 0, npcap_handler, stream_recv_arg);
		if(retv==PCAP_ERROR){
			printf("pcap_loop failed! %s\n", pcap_geterr(pcap_fp));
			break;
		}
		if(retv<0){
			printf("pcap_loop stopped: %d\n", retv);
			break;
		}
	}

	return NULL;
}


/******************************************************************************/


/*
 * Send `len` bytes of payload to the tunnel's peer inside an ICMP packet.
 *
 * The 8-byte ICMP header carries code 0x70, the tunnel id and sequence
 * number, and type CLIENT_PROTO or SERVER_PROTO depending on the run mode.
 * The client increments the tunnel's sequence number for every packet.
 *
 * tunnel: rawio_tunnel to send on
 * buf:    payload (may be NULL only when len is 0)
 * len:    payload length; must fit into the 1536-byte packet buffer
 *
 * Returns 0 on success, -1 on invalid arguments or send failure.
 */
static int rawio_send_icmp(void *tunnel, const void *buf, int len)
{
	int retv;
	int total;
	char sbuf[1536];
	rawio_tunnel *tun = (rawio_tunnel *)tunnel;
	icmp_hdr *icmp = (icmp_hdr*)sbuf;

	if(tun==NULL)
		return -1;
	/* Reject anything that would overflow the stack buffer */
	if(len<0 || len>(int)sizeof(sbuf)-8 || (buf==NULL && len>0)){
		printf("Invalid icmp payload length: %d\n", len);
		return -1;
	}
	total = len+8;

	icmp->code = 0x70;
	icmp->id   = tun->id;
	icmp->seq  = tun->seq;
	if(gcfg.run_mode==CLIENT_MODE){
		icmp->type = CLIENT_PROTO;
		tun->seq += 1;
	}else{
		icmp->type = SERVER_PROTO;
	}

	if(len>0)
		memcpy(sbuf+8, buf, len);
	/* Zero the byte after the packet in case ip_checksum() reads a whole
	 * 16-bit word for an odd length (its implementation is not visible
	 * here). */
	if(total<(int)sizeof(sbuf))
		sbuf[total] = 0;

	/* The checksum field must be zero while the checksum is computed */
	icmp->cksum = 0;
	icmp->cksum = ip_checksum((u16*)sbuf, total);

	retv = sendto(rawio_socket, (const char*)sbuf, total, 0, (struct sockaddr *)&tun->baddr, sizeof(tun->baddr));
	if(retv<=0){
		printf("Error sending icmp packet: %d\n", WSAGetLastError());
		return -1;
	}

	return 0;
}


static int rawio_send_udp(void *tunnel, const void *buf, int len)
{
	rawio_tunnel *tun = (rawio_tunnel *)tunnel;
	int retv;

	if(buf==NULL){
		buf = &retv;
		len = 4;
	}

	retv = sendto(rawio_socket, (const char*)buf, len, 0, (struct sockaddr *)&tun->baddr, sizeof(tun->baddr));
	if(retv<=0){
		printf("Error sending udp packet: %d\n", WSAGetLastError());
		return -1;
	}

	return 0;
}


/*
 * Send a payload over the given tunnel using the configured raw mode.
 *
 * tunnel: rawio_tunnel to send on
 * buf:    payload; NULL sends the magic connect/heartbeat packet
 * len:    payload length (ignored when buf is NULL)
 *
 * Returns 0 on success, -1 if the tunnel is NULL, the length is negative,
 * the raw mode is unsupported, or the send failed.
 */
int rawio_send(void *tunnel, const void *buf, int len)
{
	/* E.g. a heartbeat issued while no client tunnel exists */
	if(tunnel==NULL)
		return -1;

	if(buf==NULL){
		buf = magic_pkt;
		len = sizeof(magic_pkt);
	}
	if(len<0)
		return -1;

	if(gcfg.raw_mode==RAWMODE_ICMP){
		return rawio_send_icmp(tunnel, buf, len);
	}else if(gcfg.raw_mode==RAWMODE_UDP){
		return rawio_send_udp(tunnel, buf, len);
	}else{
		return -1;
	}
}


/******************************************************************************/


static void check_tunnel(void)
{
	int i;

	if(client_tun && client_state){
		client_tun->timeout += 1;
		if(client_tun->timeout >= TUNNEL_TIMEOUT){
			printf("%s tunnel: remote close.\n\n", rawio_tunnel_type());
			client_state = 0;
		}
	}

	for(i=0; i<MAX_TUNNEL; i++){
		if(server_tun[i]==NULL)
			continue;
		server_tun[i]->timeout += 1;
		if(server_tun[i]->timeout >= TUNNEL_TIMEOUT){
			close_tunnel(server_tun[i]);
			server_tun[i] = NULL;
		}
	}

}


/* Milliseconds accumulated since the last heartbeat tick */
static int hb_timer = 0;

/*
 * Periodic tick. `ms` is the time elapsed since the previous call.
 * Once more than 1000 ms have accumulated, the client sends a heartbeat
 * packet and all tunnels are checked for timeout.
 * Always returns 0.
 */
int rawio_heartbeat(int ms)
{
	//printf("rawio_heartbeat!\n");

	hb_timer += ms;
	if(hb_timer<=1000)
		return 0;

	hb_timer = 0;

	// Make sure the client sends one heartbeat packet per second
	if(gcfg.run_mode == CLIENT_MODE)
		rawio_send(client_tun, NULL, 0);

	check_tunnel();

	return 0;
}


/******************************************************************************/


/*
 * Initialise the raw transport.
 *
 * recv_cb: stream receive callback, stored and handed to every new KCP
 *          tunnel (expected signature:
 *          void cb(void *tun, const void *buf, int len))
 * arg:     opaque argument stored alongside the callback
 *
 * Steps: reset module state, create and configure the socket for the
 * configured raw mode, bind it (server, non-ICMP), set up pcap capture
 * (ICMP), create the client tunnel (client), start receiving, and finally
 * send the initial connect packets (client).
 *
 * Returns 0 on success, a non-zero value on failure.
 */
int rawio_init(void *recv_cb, void *arg)
{
	int i, retv;
	struct sockaddr_in bind_addr;

	memset(server_tun, 0, sizeof(server_tun));
	client_tun = NULL;
	listen_tun = NULL;
	client_state = 0;

	stream_recv_cb = recv_cb;
	stream_recv_arg = arg;

	// rawio_socket
	if(gcfg.raw_mode==RAWMODE_ICMP){
		rawio_socket = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
	}else if(gcfg.raw_mode==RAWMODE_UDP){
		rawio_socket = socket(AF_INET, SOCK_DGRAM, 0);
	}else{
		printf("FAKETCP not support!\n");
		return -1;
	}
	if (rawio_socket==INVALID_SOCKET) {
		printf("%s socket create error! %d\n", rawio_tunnel_type(), WSAGetLastError());
		return -1;
	}

	/* Ask for 2 MiB socket buffers (best effort, result not checked) */
	retv = 2*1024*1024;
	setsockopt (rawio_socket, SOL_SOCKET, SO_SNDBUF, (char*)&retv, sizeof(retv));
	setsockopt (rawio_socket, SOL_SOCKET, SO_RCVBUF, (char*)&retv, sizeof(retv));

	socket_nonblock(rawio_socket);

	if(gcfg.run_mode==SERVER_MODE && gcfg.raw_mode!=RAWMODE_ICMP){
		// On the server, TCP and UDP need to bind the listening port
		memset(&bind_addr, 0, sizeof(bind_addr));
		bind_addr.sin_family = AF_INET;
		bind_addr.sin_port = htons(gcfg.target_port);
		bind_addr.sin_addr.s_addr = INADDR_ANY;
		retv = bind(rawio_socket, (struct sockaddr *)&bind_addr, sizeof(bind_addr));
		if(retv){
			printf("bind failed! %d\n", WSAGetLastError());
			return -1;
		}
	}

	// ICMP receives packets through pcap
	if(gcfg.raw_mode==RAWMODE_ICMP){
		retv = npcap_init(gcfg.devname);
		if(retv)
			return retv;
	}

	if(gcfg.run_mode==CLIENT_MODE){
		// The client needs only one tunnel. It is created before any
		// receiver is started so that the receive path never runs
		// without a client tunnel.
		int tun_id = (gcfg.raw_mode==RAWMODE_ICMP) ? (gcfg.listen_port+gcfg.forward_port) : gcfg.target_port;
		client_tun = new_rawio_tunnel(inet_addr(gcfg.target_ip), tun_id);
		if(client_tun==NULL)
			return -1;
		listen_tun = client_tun;
	}

	/* Start receiving */
	if(gcfg.raw_mode==RAWMODE_ICMP){
		retv = pthread_create(&recv_tid, NULL, npcap_recv_thread, NULL);
		if(retv){
			printf("create recv thread failed! %d\n", retv);
			return -1;
		}
	}else if(gcfg.raw_mode==RAWMODE_UDP){
		evio_add(rawio_socket, handle_udp_recv, NULL, NULL, NULL);
		evio_enable(rawio_socket, EV_IN);
	}

	if(gcfg.run_mode==CLIENT_MODE){
		// Connect to the server: send the connect packet several times
		for(i=0; i<8; i++){
			rawio_send(client_tun, NULL, 0);
			//usleep(100000);
		}
	}

	return 0;
}

